import numpy as np
import tensorflow as tf
import core.utils as utils
import core.common as common
import core.backbone as backbone
from core.config import cfg


class YOLOV3(object):
    """YOLOv3 detection graph: backbone, three detection heads, decoding and losses.

    The constructor builds the whole graph eagerly (TF1-style, using
    ``tf.variable_scope``) and exposes the results as attributes:

    * ``conv_lbbox`` / ``conv_mbbox`` / ``conv_sbbox`` -- raw (un-activated)
      head outputs for the large / medium / small object branches.
    * ``pred_lbbox`` / ``pred_mbbox`` / ``pred_sbbox`` -- the same outputs
      passed through :meth:`decode`.

    Arguments:
        input_data: image tensor fed to ``backbone.darknet53``.
        trainable: flag forwarded unchanged to ``backbone.darknet53`` and to
            every ``common.convolutional`` call.
    """

    def __init__(self, input_data, trainable):
        # Train / inference switch, forwarded to every layer builder.
        self.trainable        = trainable
        # Class names loaded from the file named by cfg.YOLO.CLASSES
        # (a dict of names according to the original author's note).
        self.classes          = utils.read_class_names(cfg.YOLO.CLASSES)
        # Number of classes.
        self.num_class        = len(self.classes)
        # Down-sampling factor of each detection scale (the original note
        # mentions 8, 16 and 32; the actual values come from the config).
        self.strides          = np.array(cfg.YOLO.STRIDES)
        # Anchor boxes, indexed by scale (anchors[0] is paired with strides[0], ...).
        self.anchors          = utils.get_anchors(cfg.YOLO.ANCHORS)
        # Number of anchor boxes per scale.
        self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
        # IoU threshold below which a non-responsible prediction is treated
        # as background in the confidence loss (see loss_layer).
        self.iou_loss_thresh  = cfg.YOLO.IOU_LOSS_THRESH
        # Up-sampling method name handed to common.upsample.
        self.upsample_method  = cfg.YOLO.UPSAMPLE_METHOD

        # Raw outputs of the three detection heads (large, medium, small).
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed, and chain the
        # original error so the real cause stays visible in the traceback.
        try:
            self.conv_lbbox, self.conv_mbbox, self.conv_sbbox = self.__build_nework(input_data)
        except Exception as err:
            raise NotImplementedError("Can not build up yolov3 network!") from err

        # Decode each head with the anchors / stride of its own scale:
        # index 0 -> small-object head, 1 -> medium, 2 -> large.
        with tf.variable_scope('pred_sbbox'):
            self.pred_sbbox = self.decode(self.conv_sbbox, self.anchors[0], self.strides[0])

        with tf.variable_scope('pred_mbbox'):
            self.pred_mbbox = self.decode(self.conv_mbbox, self.anchors[1], self.strides[1])

        with tf.variable_scope('pred_lbbox'):
            self.pred_lbbox = self.decode(self.conv_lbbox, self.anchors[2], self.strides[2])

    def __build_nework(self, input_data):
        """Build the YOLOv3 network on top of the Darknet-53 backbone.

        Args:
            input_data: image tensor passed to ``backbone.darknet53``.

        Returns:
            Tuple ``(conv_lbbox, conv_mbbox, conv_sbbox)`` of raw head outputs
            (no activation, no batch norm on the last layer), each with
            ``3 * (num_class + 5)`` output channels.
        """
        # NOTE(review): the shape tuples passed to common.convolutional look
        # like (kernel_h, kernel_w, in_channels, out_channels) -- confirm
        # against core.common.
        # NOTE(review): the head channel count hard-codes 3 anchors per scale,
        # while loss_layer reshapes with self.anchor_per_scale; the two only
        # agree when cfg.YOLO.ANCHOR_PER_SCALE == 3.

        # Darknet-53 backbone: two intermediate routes plus the final feature map.
        route_1, route_2, input_data = backbone.darknet53(input_data, self.trainable)

        # Five alternating 1x1 / 3x3 convolutions.
        input_data = common.convolutional(input_data, (1, 1, 1024,  512), self.trainable, 'conv52')
        input_data = common.convolutional(input_data, (3, 3,  512, 1024), self.trainable, 'conv53')
        input_data = common.convolutional(input_data, (1, 1, 1024,  512), self.trainable, 'conv54')
        input_data = common.convolutional(input_data, (3, 3,  512, 1024), self.trainable, 'conv55')
        input_data = common.convolutional(input_data, (1, 1, 1024,  512), self.trainable, 'conv56')

        # Large-object head (coarsest grid).
        conv_lobj_branch = common.convolutional(input_data, (3, 3, 512, 1024), self.trainable, name='conv_lobj_branch')
        conv_lbbox = common.convolutional(conv_lobj_branch, (1, 1, 1024, 3*(self.num_class + 5)),
                                          trainable=self.trainable, name='conv_lbbox', activate=False, bn=False)

        # Reduce channels, up-sample and merge with the second backbone route.
        input_data = common.convolutional(input_data, (1, 1,  512,  256), self.trainable, 'conv57')
        input_data = common.upsample(input_data, name='upsample0', method=self.upsample_method)

        with tf.variable_scope('route_1'):
            input_data = tf.concat([input_data, route_2], axis=-1)

        input_data = common.convolutional(input_data, (1, 1, 768, 256), self.trainable, 'conv58')
        input_data = common.convolutional(input_data, (3, 3, 256, 512), self.trainable, 'conv59')
        input_data = common.convolutional(input_data, (1, 1, 512, 256), self.trainable, 'conv60')
        input_data = common.convolutional(input_data, (3, 3, 256, 512), self.trainable, 'conv61')
        input_data = common.convolutional(input_data, (1, 1, 512, 256), self.trainable, 'conv62')

        # Medium-object head.
        conv_mobj_branch = common.convolutional(input_data, (3, 3, 256, 512),  self.trainable, name='conv_mobj_branch' )
        conv_mbbox = common.convolutional(conv_mobj_branch, (1, 1, 512, 3*(self.num_class + 5)),
                                          trainable=self.trainable, name='conv_mbbox', activate=False, bn=False)

        # Reduce channels, up-sample and merge with the first backbone route.
        input_data = common.convolutional(input_data, (1, 1, 256, 128), self.trainable, 'conv63')
        input_data = common.upsample(input_data, name='upsample1', method=self.upsample_method)

        with tf.variable_scope('route_2'):
            input_data = tf.concat([input_data, route_1], axis=-1)

        input_data = common.convolutional(input_data, (1, 1, 384, 128), self.trainable, 'conv64')
        input_data = common.convolutional(input_data, (3, 3, 128, 256), self.trainable, 'conv65')
        input_data = common.convolutional(input_data, (1, 1, 256, 128), self.trainable, 'conv66')
        input_data = common.convolutional(input_data, (3, 3, 128, 256), self.trainable, 'conv67')
        input_data = common.convolutional(input_data, (1, 1, 256, 128), self.trainable, 'conv68')

        # Small-object head (finest grid).
        conv_sobj_branch = common.convolutional(input_data, (3, 3, 128, 256), self.trainable, name='conv_sobj_branch')
        conv_sbbox = common.convolutional(conv_sobj_branch, (1, 1, 256, 3*(self.num_class + 5)),
                                          trainable=self.trainable, name='conv_sbbox', activate=False, bn=False)

        return conv_lbbox, conv_mbbox, conv_sbbox

    def decode(self, conv_output, anchors, stride):
        """Turn a raw head output into box / confidence / class predictions.

        Args:
            conv_output: raw output of one detection head.
            anchors: anchor boxes of this scale; ``len(anchors)`` is used as
                the number of anchors per grid cell.
            stride: down-sampling factor of this scale.

        Returns:
            Tensor of shape ``[batch_size, output_size, output_size,
            anchor_per_scale, 5 + num_class]`` holding
            ``(x, y, w, h, score, class probabilities)``.
        """
        conv_shape       = tf.shape(conv_output)
        batch_size       = conv_shape[0]
        # Only dimension 1 is read, so the feature map is treated as square.
        output_size      = conv_shape[1]
        anchor_per_scale = len(anchors)

        conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, anchor_per_scale, 5 + self.num_class))

        # Split the last axis into its four parts.
        conv_raw_dxdy = conv_output[:, :, :, :, 0:2]   # raw centre offsets
        conv_raw_dwdh = conv_output[:, :, :, :, 2:4]   # raw size offsets
        conv_raw_conf = conv_output[:, :, :, :, 4:5]   # raw objectness logit
        conv_raw_prob = conv_output[:, :, :, :, 5: ]   # raw class logits

        # Build the grid: y holds the row index, x the column index of every cell.
        y = tf.tile(tf.range(output_size, dtype=tf.int32)[:, tf.newaxis], [1, output_size])
        x = tf.tile(tf.range(output_size, dtype=tf.int32)[tf.newaxis, :], [output_size, 1])

        # (x, y) index of each cell, replicated over batch and anchors.
        xy_grid = tf.concat([x[:, :, tf.newaxis], y[:, :, tf.newaxis]], axis=-1)
        xy_grid = tf.tile(xy_grid[tf.newaxis, :, :, tf.newaxis, :], [batch_size, 1, 1, anchor_per_scale, 1])
        xy_grid = tf.cast(xy_grid, tf.float32)

        # Box centre: sigmoid offset inside the cell plus the cell index,
        # scaled by the stride.
        pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * stride
        # Box size: exponential of the raw offset times the anchor size,
        # scaled by the stride.
        pred_wh = (tf.exp(conv_raw_dwdh) * anchors) * stride
        pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
        # Objectness score and per-class probabilities.
        pred_conf = tf.sigmoid(conv_raw_conf)
        pred_prob = tf.sigmoid(conv_raw_prob)

        return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

    def focal(self, target, actual, alpha=1, gamma=2):
        """Return the focal modulating factor ``alpha * |target - actual| ** gamma``."""
        focal_loss = alpha * tf.pow(tf.abs(target - actual), gamma)
        return focal_loss

    def bbox_giou(self, boxes1, boxes2):
        """Compute the generalized IoU (GIoU) between two sets of boxes.

        Arguments:
            boxes1: boxes as ``[x, y, w, h]`` on the last axis, (x, y) being
                the box centre (the predicted boxes at the call site).
            boxes2: boxes in the same format (the ground-truth boxes at the
                call site).

        Returns:
            GIoU of ``boxes1`` and ``boxes2`` with the last axis removed.
        """
        # NOTE(review): the two divisions below are unguarded; a zero union or
        # enclosing area would yield NaN/Inf.

        # Convert [x, y, w, h] to [x_min, y_min, x_max, y_max].
        boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                            boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
        boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                            boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)
        # Guarantee x_min <= x_max and y_min <= y_max.
        boxes1 = tf.concat([tf.minimum(boxes1[..., :2], boxes1[..., 2:]),
                            tf.maximum(boxes1[..., :2], boxes1[..., 2:])], axis=-1)
        boxes2 = tf.concat([tf.minimum(boxes2[..., :2], boxes2[..., 2:]),
                            tf.maximum(boxes2[..., :2], boxes2[..., 2:])], axis=-1)

        boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
        boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

        # Intersection rectangle, clamped to zero size when the boxes do not overlap.
        left_up = tf.maximum(boxes1[..., :2], boxes2[..., :2])
        right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])
        inter_section = tf.maximum(right_down - left_up, 0.0)
        inter_area = inter_section[..., 0] * inter_section[..., 1]

        union_area = boxes1_area + boxes2_area - inter_area
        iou = inter_area / union_area

        # Smallest axis-aligned rectangle enclosing both boxes.
        enclose_left_up = tf.minimum(boxes1[..., :2], boxes2[..., :2])
        enclose_right_down = tf.maximum(boxes1[..., 2:], boxes2[..., 2:])
        enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0)
        enclose_area = enclose[..., 0] * enclose[..., 1]

        # GIoU = IoU - (part of the enclosing box not covered by the union) / enclosing area.
        giou = iou - 1.0 * (enclose_area - union_area) / enclose_area

        return giou

    def bbox_iou(self, boxes1, boxes2):
        """Compute the plain IoU between two sets of boxes.

        Arguments:
            boxes1: boxes as ``[x, y, w, h]`` on the last axis, (x, y) being
                the box centre.
            boxes2: boxes in the same format.

        Returns:
            IoU of ``boxes1`` and ``boxes2`` with the last axis removed.
        """
        # NOTE(review): the division below is unguarded; a zero union area
        # would yield NaN.

        # Areas are taken straight from w * h before the format conversion.
        boxes1_area = boxes1[..., 2] * boxes1[..., 3]
        boxes2_area = boxes2[..., 2] * boxes2[..., 3]

        # Convert [x, y, w, h] to [x_min, y_min, x_max, y_max].
        boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                            boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
        boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                            boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)

        # Intersection rectangle, clamped to zero size when the boxes do not overlap.
        left_up = tf.maximum(boxes1[..., :2], boxes2[..., :2])
        right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])
        inter_section = tf.maximum(right_down - left_up, 0.0)
        inter_area = inter_section[..., 0] * inter_section[..., 1]

        union_area = boxes1_area + boxes2_area - inter_area
        iou = 1.0 * inter_area / union_area

        return iou

    def loss_layer(self, conv, pred, label, bboxes, anchors, stride):
        """Compute the three loss terms for one detection scale.

        Arguments:
            conv: raw head output of this scale; reshaped here to
                ``[batch, size, size, anchor_per_scale, 5 + num_class]``.
            pred: decoded predictions ``[x, y, w, h, conf, prob...]`` for the
                same scale (output of :meth:`decode`).
            label: per-cell labels laid out like ``pred``: box in ``0:4``,
                responsibility mask in ``4:5``, class targets in ``5:``.
            bboxes: ground-truth boxes ``[x, y, w, h]``; broadcast against
                every prediction (assumes a ``[batch, num_boxes, 4]`` tensor
                -- TODO confirm against the data pipeline).
            anchors: unused in this method; kept for interface compatibility.
            stride: down-sampling factor of this scale.

        Return:
            Tuple ``(giou_loss, conf_loss, prob_loss)``; each term is summed
            over grid cells, anchors and channels, then averaged over the batch.
        """
        conv_shape  = tf.shape(conv)
        batch_size  = conv_shape[0]
        output_size = conv_shape[1]
        # Input image size implied by the feature-map size and the stride.
        input_size  = stride * output_size
        conv = tf.reshape(conv, (batch_size, output_size, output_size,
                                 self.anchor_per_scale, 5 + self.num_class))
        conv_raw_conf = conv[:, :, :, :, 4:5]
        conv_raw_prob = conv[:, :, :, :, 5:]

        pred_xywh     = pred[:, :, :, :, 0:4]
        pred_conf     = pred[:, :, :, :, 4:5]

        label_xywh    = label[:, :, :, :, 0:4]
        # Multiplicative mask that switches the box and class losses on per
        # anchor slot (per the original note: 1 where a ground-truth centre
        # falls in the cell, 0 elsewhere).
        respond_bbox  = label[:, :, :, :, 4:5]
        label_prob    = label[:, :, :, :, 5:]

        giou = tf.expand_dims(self.bbox_giou(pred_xywh, label_xywh), axis=-1)
        input_size = tf.cast(input_size, tf.float32)
        # The smaller the labelled box relative to the image, the larger the
        # weight (between 1 and 2 for boxes that fit inside the image).
        bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)

        # Box regression loss, only for responsible anchors.
        giou_loss = respond_bbox * bbox_loss_scale * (1- giou)

        # IoU of every prediction against every ground-truth box.
        iou = self.bbox_iou(pred_xywh[:, :, :, :, np.newaxis, :], bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :])
        # For each prediction, the best IoU over all ground-truth boxes.
        max_iou = tf.expand_dims(tf.reduce_max(iou, axis=-1), axis=-1)
        # A non-responsible prediction counts as background only when its best
        # IoU is below the threshold; the rest are ignored by the confidence loss.
        respond_bgd = (1.0 - respond_bbox) * tf.cast( max_iou < self.iou_loss_thresh, tf.float32 )

        conf_focal = self.focal(respond_bbox, pred_conf)

        # Both the object and the background term use the same cross-entropy
        # (labels = respond_bbox), so it is computed once and masked twice.
        conf_cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
        conf_loss = conf_focal * (
                respond_bbox * conf_cross_entropy
                +
                respond_bgd * conf_cross_entropy
        )

        # Classification loss, only for responsible anchors.
        prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_prob, logits=conv_raw_prob)

        giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis=[1,2,3,4]))
        conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis=[1,2,3,4]))
        prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis=[1,2,3,4]))

        return giou_loss, conf_loss, prob_loss

    def compute_loss(self, label_sbbox, label_mbbox, label_lbbox, true_sbbox, true_mbbox, true_lbbox):
        """Sum the loss terms of the three detection scales.

        Arguments:
            label_sbbox, label_mbbox, label_lbbox: per-cell labels
                ``[x, y, w, h, confidence, class targets]`` for the small,
                medium and large scale respectively.
            true_sbbox, true_mbbox, true_lbbox: ground-truth boxes
                ``[x, y, w, h]`` for the corresponding scale.

        Return:
            giou_loss: box regression loss summed over the three scales.
            conf_loss: confidence loss summed over the three scales.
            prob_loss: classification loss summed over the three scales.
        """
        # Small-object scale.
        with tf.name_scope('smaller_box_loss'):
            loss_sbbox = self.loss_layer(self.conv_sbbox, self.pred_sbbox, label_sbbox, true_sbbox,
                                         anchors = self.anchors[0], stride = self.strides[0])
        # Medium-object scale.
        with tf.name_scope('medium_box_loss'):
            loss_mbbox = self.loss_layer(self.conv_mbbox, self.pred_mbbox, label_mbbox, true_mbbox,
                                         anchors = self.anchors[1], stride = self.strides[1])
        # Large-object scale.
        with tf.name_scope('bigger_box_loss'):
            loss_lbbox = self.loss_layer(self.conv_lbbox, self.pred_lbbox, label_lbbox, true_lbbox,
                                         anchors = self.anchors[2], stride = self.strides[2])

        with tf.name_scope('giou_loss'):
            giou_loss = loss_sbbox[0] + loss_mbbox[0] + loss_lbbox[0]

        with tf.name_scope('conf_loss'):
            conf_loss = loss_sbbox[1] + loss_mbbox[1] + loss_lbbox[1]

        with tf.name_scope('prob_loss'):
            prob_loss = loss_sbbox[2] + loss_mbbox[2] + loss_lbbox[2]

        return giou_loss, conf_loss, prob_loss


